Machine learning to segment neutron images
Anders Kaestner, Beamline scientist - Neutron Imaging
Laboratory for Neutron Scattering and Imaging
Paul Scherrer Institut
This lecture needs some modules to run. We import all of them here.
import matplotlib.pyplot as plt
import seaborn as sn
import numpy as np
import pandas as pd
import skimage.filters as flt
import skimage.io as io
import matplotlib as mpl
from sklearn.cluster import KMeans
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import confusion_matrix
from sklearn.datasets import make_blobs
from matplotlib.colors import ListedColormap
from lecturesupport import plotsupport as ps
import scipy.stats as stats
import astropy.io.fits as fits
from keras.models import Model
from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'png')  # prefer vector output for crisp figures
#plt.style.use('seaborn')
mpl.rcParams['figure.dpi'] = 300  # high-resolution figures throughout the lecture
Using TensorFlow backend.
import importlib
importlib.reload(ps);  # pick up any local edits to the plotting helper module
Introduction to neutron imaging
Introduction to segmentation
Problematic segmentation tasks
A very abstract definition:
In most cases this is a two- or three-dimensional position (x,y,z coordinates) and a numeric value (intensity)
Images are great for qualitative analyses since our brains can quickly interpret them without large programming investments.
| Transmission through sample | X-ray attenuation | Neutron attenuation |
Start out with a simple image of a cross with added noise
$$ I(x,y) = f(x,y) $$

fig,ax = plt.subplots(1,2,figsize=(12,6))
nx = 5; ny = 5;
# coordinate grids spanning roughly -2*pi..2*pi in both directions
xx, yy = np.meshgrid(np.arange(-nx, nx+1)/nx*2*np.pi, np.arange(-ny, ny+1)/ny*2*np.pi)
# synthetic cross-shaped pattern plus additive uniform noise in [-0.25, 0.25]
cross_im = 1.5*np.abs(np.cos(xx*yy))/(np.abs(xx*yy)+(3*np.pi/nx)) + np.random.uniform(-0.25, 0.25, size = xx.shape)
im=ax[0].imshow(cross_im, cmap = 'hot'); ax[0].set_title("Image")
ax[1].hist(cross_im.ravel(),bins=10); ax[1].set_xlabel('Gray value'); ax[1].set_ylabel('Counts'); ax[1].set_title("Histogram");
Applying the threshold is a deceptively simple operation
$$ I(x,y) = \begin{cases} 1, & f(x,y)\geq0.40 \\ 0, & f(x,y)<0.40 \end{cases}$$

threshold = 0.4; thresh_img = cross_im > threshold
fig,ax = plt.subplots(1,2,figsize=(12,6))
ax[0].imshow(cross_im, cmap = 'hot', extent = [xx.min(), xx.max(), yy.min(), yy.max()]); ax[0].set_title("Image")
# overlay the above-threshold pixels as markers (coordinates scaled by 0.9,
# presumably to keep the large markers inside the axes — TODO confirm)
ax[0].plot(xx[np.where(thresh_img)]*0.9, yy[np.where(thresh_img)]*0.9,
    'ks', markerfacecolor = 'green', alpha = 0.5,label = 'Threshold', markersize = 22); ax[0].legend(fontsize=12);
# histogram with the threshold marked as a vertical line
ax[1].hist(cross_im.ravel(),bins=10); ax[1].axvline(x=threshold,color='r',label='Threshold'); ax[1].legend(fontsize=12);
ax[1].set_xlabel('Gray value'); ax[1].set_ylabel('Counts'); ax[1].set_title("Histogram");
The noise in neutron imaging mainly originates from the amount of captured neutrons.
This noise is Poisson distributed and the signal to noise ratio is
$$SNR=\frac{E[x]}{s[x]}\sim\frac{N}{\sqrt{N}}=\sqrt{N}$$
Woodland Encounter Bev Doolittle
Different types of limited data:
# 200 synthetic 2D points drawn from gaussian blobs (fixed seed for reproducibility)
test_pts = pd.DataFrame(make_blobs(n_samples=200, random_state=2018)[
    0], columns=['x', 'y'])
plt.plot(test_pts.x, test_pts.y, 'r.');
fig, ax = plt.subplots(1,3,figsize=(15,4.5))
# cluster the same data with k = 2, 3, 4 to show how the grouping changes
for i in range(3) :
    km = KMeans(n_clusters=i+2, random_state=2018); n_grp = km.fit_predict(test_pts)
    ax[i].scatter(test_pts.x, test_pts.y, c=n_grp)
    ax[i].set_title('{0} groups'.format(i+2))
# time-of-flight (ToF) image stack: two spatial axes plus one spectral (ToF) axis
tof = np.load('../data/tofdata.npy')
wtof = tof.mean(axis=2)  # averaging over the ToF axis gives a plain 2D image
plt.imshow(wtof);
# flatten to one row per pixel (each row is that pixel's ToF spectrum) for clustering
tofr=tof.reshape([tof.shape[0]*tof.shape[1],tof.shape[2]])
print("Input ToF dimensions",tof.shape)
print("Reshaped ToF data",tofr.shape)
Input ToF dimensions (128, 128, 661) Reshaped ToF data (16384, 661)
# cluster the per-pixel ToF spectra into four groups
km = KMeans(n_clusters=4, random_state=2018)
c = km.fit_predict(tofr).reshape(tof.shape[:2]) # Label image
kc = km.cluster_centers_.transpose() # cluster centroid spectra
Results from the first try
# show average image, centroid spectra, and the resulting label map side by side
fig,axes = plt.subplots(1,3,figsize=(18,5)); axes=axes.ravel()
axes[0].imshow(wtof,cmap='viridis'); axes[0].set_title('Average image')
p=axes[1].plot(kc); axes[1].set_title('Cluster centroid spectra'); axes[1].set_aspect(tof.shape[2], adjustable='box')
cmap=ps.buildCMap(p) # Create a color map with the same colors as the plot
im=axes[2].imshow(c,cmap=cmap); plt.colorbar(im);
axes[2].set_title('Cluster map');
plt.tight_layout()
# repeat the clustering with k=10 to illustrate the effect of too many clusters
km = KMeans(n_clusters=10, random_state=2018)
c = km.fit_predict(tofr).reshape(tof.shape[:2]) # Label image
kc = km.cluster_centers_.transpose() # cluster centroid spectra
fig,axes = plt.subplots(1,3,figsize=(18,5)); axes=axes.ravel()
axes[0].imshow(wtof,cmap='viridis'); axes[0].set_title('Average image')
p=axes[1].plot(kc); axes[1].set_title('Cluster centroid spectra'); axes[1].set_aspect(tof.shape[2], adjustable='box')
cmap=ps.buildCMap(p) # Create a color map with the same colors as the plot
im=axes[2].imshow(c,cmap=cmap); plt.colorbar(im);
axes[2].set_title('Cluster map');
plt.tight_layout()
fig,axes = plt.subplots(1,2,figsize=(14,5)); axes=axes.ravel()
# correlation between centroid spectra — highly correlated rows indicate redundant clusters
axes[0].matshow(np.corrcoef(kc.transpose()))
axes[1].plot(kc); axes[1].set_title('Cluster centroid spectra'); axes[1].set_aspect(tof.shape[2], adjustable='box')
del km, c, kc, tofr, tof  # release the large ToF arrays before the next section
# labelled blob data for the supervised-classification example
blob_data, blob_labels = make_blobs(n_samples=100, random_state=2018)
test_pts = pd.DataFrame(blob_data, columns=['x', 'y'])
test_pts['group_id'] = blob_labels
plt.scatter(test_pts.x, test_pts.y, c=test_pts.group_id, cmap='viridis');
# neutron radiograph (FITS) and its manual annotation (PNG);
# pixels with green channel == 0 in the annotation mark spots
orig= fits.getdata('../data/spots/mixture12_00001.fits')
annotated=io.imread('../data/spots/mixture12_00001.png'); mask=(annotated[:,:,1]==0)
r=600; c=600; w=256  # region of interest (row, column, width) for the magnified view
ps.magnifyRegion(orig,[r,c,r+w,c+w],[15,7],vmin=400,vmax=4000,title='Neutron radiography')
Parameters
def spotCleaner(img, threshold=0.95, selem=None):
    """Detect and replace outlier (spot) pixels using a median filter.

    A pixel is flagged as a spot when it deviates from the local median
    by more than ``threshold``; flagged pixels are replaced by the median
    value, all other pixels are kept unchanged.

    Parameters
    ----------
    img : ndarray
        Input image; converted to float32 internally.
    threshold : float, optional
        Absolute deviation from the local median above which a pixel is
        considered a spot.
    selem : ndarray, optional
        Structuring element for the median filter. Defaults to a 3x3 box.
        (Kept as ``None`` default to avoid a shared mutable default argument.)

    Returns
    -------
    tuple of ndarray
        ``(cleaned, timg)`` — the cleaned image and the boolean detection mask.
    """
    if selem is None:
        selem = np.ones([3, 3])  # fresh array per call, not a shared module-level default
    fimg = img.astype('float32')
    mimg = flt.median(fimg, selem=selem)      # local median estimate
    timg = threshold < np.abs(fimg - mimg)    # True where the pixel deviates strongly
    cleaned = mimg * timg + fimg * (1 - timg) # replace flagged pixels by the median
    return (cleaned, timg)
# baseline spot cleaning with a fixed deviation threshold of 1000 gray values
baseclean,timg = spotCleaner(orig,threshold=1000)
ps.magnifyRegion(baseclean,[r,c,r+w,c+w],[12,3],vmin=400,vmax=4000,title='Cleaned image')
ps.magnifyRegion(timg,[r,c,r+w,c+w],[12,3],vmin=0,vmax=1,title='Detection image')
# build the two features used by the classifier: intensity and median deviation
selem=np.ones([3,3])
forig=orig.astype('float32')
mimg = flt.median(forig,selem=selem)
d = np.abs(forig-mimg)  # absolute deviation from the 3x3 local median
fig,ax=plt.subplots(1,1,figsize=(8,5))
# bivariate histogram of (intensity, deviation) over the first 1024 rows
h,x,y,u=ax.hist2d(forig[:1024,:].ravel(),d[:1024,:].ravel(), bins=100);
ax.imshow(np.log(h[::-1]+1),vmin=0,vmax=3,extent=[x.min(),x.max(),y.min(),y.max()])
ax.set_xlabel('Input image - $f$'),ax.set_ylabel('$|f-med_{3x3}(f)|$'),ax.set_title('Log bivariate histogram');
Training data
# training set: first 1000 columns of the image, one feature row per pixel
trainorig = forig[:,:1000].ravel()
traind = d[:,:1000].ravel()
trainmask = mask[:,:1000].ravel()
train_pts = pd.DataFrame({'orig': trainorig, 'd': traind, 'mask':trainmask})
Test data
# test set: remaining columns (from column 1000 onward)
testorig = forig[:,1000:].ravel()
testd = d[:,1000:].ravel()
testmask = mask[:,1000:].ravel()
test_pts = pd.DataFrame({'orig': testorig, 'd': testd, 'mask':testmask})
# 1-nearest-neighbour classifier on the (intensity, deviation) features
k_class = KNeighborsClassifier(1)
k_class.fit(train_pts[['orig', 'd']], train_pts['mask'])
KNeighborsClassifier(n_neighbors=1)
Inspect decision space
# sample a regular grid over the feature space to visualize the decision regions
xx, yy = np.meshgrid(np.linspace(test_pts.orig.min(), test_pts.orig.max(), 100),
                     np.linspace(test_pts.d.min(), test_pts.d.max(), 100),indexing='ij');
grid_pts = pd.DataFrame(dict(x=xx.ravel(), y=yy.ravel()))
grid_pts['predicted_id'] = k_class.predict(grid_pts[['x', 'y']])
plt.scatter(grid_pts.x, grid_pts.y, c=grid_pts.predicted_id, cmap='gray'); plt.title('Testing Points'); plt.axis('square');
# classify the held-out pixels and fold the predictions back into image geometry
pred = k_class.predict(test_pts[['orig', 'd']])
# NOTE(review): the test split above uses columns ([:,1000:]) but the reshape and
# the displays below use row slices ([1000:,:]) — verify the intended split axis.
pimg = pred.reshape(d[1000:,:].shape)
fig,ax = plt.subplots(1,3,figsize=(15,6))
ax[0].imshow(forig[1000:,:],vmin=0,vmax=4000), ax[0].set_title('Original image')
ax[1].imshow(pimg), ax[1].set_title('Predicted spot')
ax[2].imshow(mask[1000:,:]),ax[2].set_title('Annotated spots');
# normalized confusion matrices: baseline threshold detector vs the k-NN classifier
cmbase = confusion_matrix(mask[:,1000:].ravel(), timg[:,1000:].ravel(), normalize='all')
cmknn = confusion_matrix(mask[:,1000:].ravel(), pimg.ravel(), normalize='all')
fig,ax = plt.subplots(1,2,figsize=(10,4))
sn.heatmap(cmbase, annot=True,ax=ax[0]), ax[0].set_title('Confusion matrix baseline');
sn.heatmap(cmknn, annot=True,ax=ax[1]), ax[1].set_title('Confusion matrix k-NN');
Note There are other spot detection methods that perform better than the baseline.
del k_class, cmbase, cmknn  # free the classifier section's objects before the CNN part
import keras.optimizers as opt
import keras.losses as loss
import keras.metrics as metrics
We have two choices:
We will use the spotty image as training data for this example
Any analysis system must be verified to demonstrate its performance and to allow further optimization.
For this we need to split our data into three categories:
| Training | Validation | Test |
|---|---|---|
| 70% | 15% | 15% |
def buildSpotUNet( base_depth = 48) :
    """Construct a small two-level U-Net for spot segmentation.

    Parameters
    ----------
    base_depth : int
        Filter count of the first level; deeper levels use multiples of it.

    Returns
    -------
    Keras Model mapping a single-channel image of arbitrary size to a
    single-channel relu response map.
    """
    def double_conv(tensor, depth):
        # two stacked 3x3 same-padding relu convolutions
        tensor = Conv2D(depth, kernel_size=(3, 3), padding='same', activation='relu')(tensor)
        return Conv2D(depth, kernel_size=(3, 3), padding='same', activation='relu')(tensor)

    in_img = Input((None, None, 1), name='Image_Input')
    # contracting path
    enc1 = double_conv(in_img, base_depth)
    pool1 = MaxPooling2D(pool_size=(2, 2))(enc1)
    enc2 = double_conv(pool1, base_depth*2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(enc2)
    # bottleneck
    bottom = double_conv(pool2, base_depth*4)
    # expanding path with skip connections to the matching encoder level
    up2 = UpSampling2D((2, 2))(bottom)
    dec2 = double_conv(concatenate([enc2, up2]), base_depth*2)
    up1 = UpSampling2D((2, 2))(dec2)
    dec1 = double_conv(concatenate([enc1, up1]), base_depth)
    # 1x1 convolution collapses the features to a single response channel
    out = Conv2D(1, kernel_size=(1, 1), padding='same',
                 activation='relu')(dec1)
    return Model(inputs=[in_img], outputs=[out], name='SpotUNET')
Model summary
# build a reduced-size network (24 base filters) and print its layer summary
t_unet = buildSpotUNet(base_depth=24)
t_unet.summary()
WARNING:tensorflow:From /home/travis/miniconda/envs/book/lib/python3.7/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
WARNING:tensorflow:From /home/travis/miniconda/envs/book/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:4070: The name tf.nn.max_pool is deprecated. Please use tf.nn.max_pool2d instead.
Model: "SpotUNET"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
Image_Input (InputLayer) (None, None, None, 1 0
__________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, None, None, 2 240 Image_Input[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D) (None, None, None, 2 5208 conv2d_1[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D) (None, None, None, 2 0 conv2d_2[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D) (None, None, None, 4 10416 max_pooling2d_1[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D) (None, None, None, 4 20784 conv2d_3[0][0]
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D) (None, None, None, 4 0 conv2d_4[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D) (None, None, None, 9 41568 max_pooling2d_2[0][0]
__________________________________________________________________________________________________
conv2d_6 (Conv2D) (None, None, None, 9 83040 conv2d_5[0][0]
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D) (None, None, None, 9 0 conv2d_6[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate) (None, None, None, 1 0 conv2d_4[0][0]
up_sampling2d_1[0][0]
__________________________________________________________________________________________________
conv2d_7 (Conv2D) (None, None, None, 4 62256 concatenate_1[0][0]
__________________________________________________________________________________________________
conv2d_8 (Conv2D) (None, None, None, 4 20784 conv2d_7[0][0]
__________________________________________________________________________________________________
up_sampling2d_2 (UpSampling2D) (None, None, None, 4 0 conv2d_8[0][0]
__________________________________________________________________________________________________
concatenate_2 (Concatenate) (None, None, None, 7 0 conv2d_2[0][0]
up_sampling2d_2[0][0]
__________________________________________________________________________________________________
conv2d_9 (Conv2D) (None, None, None, 2 15576 concatenate_2[0][0]
__________________________________________________________________________________________________
conv2d_10 (Conv2D) (None, None, None, 2 5208 conv2d_9[0][0]
__________________________________________________________________________________________________
conv2d_11 (Conv2D) (None, None, None, 1 25 conv2d_10[0][0]
==================================================================================================
Total params: 265,105
Trainable params: 265,105
Non-trainable params: 0
__________________________________________________________________________________________________
# crop training and validation regions (image + annotation) from the radiograph
train_img, valid_img = forig[128:256, 500:1300], forig[500:1000, 300:1500]
train_mask, valid_mask = mask[128:256, 500:1300], mask[500:1000, 300:1500]
wpos = [600,600]; ww = 512  # window position and size used for test predictions
forigc = forig[wpos[0]:(wpos[0]+ww),wpos[1]:(wpos[1]+ww)]
maskc = mask[wpos[0]:(wpos[0]+ww),wpos[1]:(wpos[1]+ww)]
# alternative (larger) crops kept for experimentation:
# train_img, valid_img = forig[128:256, 300:1500], forig[500:, 300:1500]
# train_mask, valid_mask = mask[128:256, 300:1500], mask[500:, 300:1500]
fig, ax = plt.subplots(1, 4, figsize=(15, 6), dpi=300); ax=ax.ravel()
ax[0].imshow(train_img, cmap='bone',vmin=0,vmax=4000);ax[0].set_title('Train Image')
ax[1].imshow(train_mask, cmap='bone'); ax[1].set_title('Train Mask')
ax[2].imshow(valid_img, cmap='bone',vmin=0,vmax=4000); ax[2].set_title('Validation Image')
ax[3].imshow(valid_mask, cmap='bone');ax[3].set_title('Validation Mask');
def prep_img(x, n=1):
    """Stack an image n times and normalize it with the training-image statistics."""
    mu, sigma = train_img.mean(), train_img.std()
    return (prep_mask(x, n=n) - mu) / sigma
def prep_mask(x, n=1):
    """Append a trailing channel axis and repeat the image n times along a new batch axis."""
    with_channel = np.expand_dims(x, -1)
    return np.stack([with_channel] * n, 0)
# prediction of the (still untrained) network on the test window
unet_pred = t_unet.predict(prep_img(forigc))[0, :, :, 0]
WARNING:tensorflow:From /home/travis/miniconda/envs/book/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:422: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead.
# compare training data, untrained prediction, and ground truth in one figure
fig, m_axs = plt.subplots(2, 3, figsize=(15, 6), dpi=150)
for c_ax in m_axs.ravel():
    c_ax.axis('off')
((ax1, _, ax2), (ax3, ax4, ax5)) = m_axs
ax1.imshow(train_img, cmap='bone',vmin=0,vmax=4000); ax1.set_title('Train Image')
ax2.imshow(train_mask, cmap='viridis'); ax2.set_title('Train Mask')
ax3.imshow(forigc, cmap='bone',vmin=0, vmax=4000); ax3.set_title('Test Image')
ax4.imshow(unet_pred, cmap='viridis', vmin=0, vmax=1); ax4.set_title('Predicted Segmentation')
ax5.imshow(maskc, cmap='viridis'); ax5.set_title('Ground Truth');
Another popular metric is the Dice score $$DSC=\frac{2|X \cap Y|}{|X|+|Y|}=\frac{2\,TP}{2TP+FP+FN}$$
# metrics tracked during training (confusion-matrix counts plus summary scores)
mlist = [
    metrics.TruePositives(name='tp'), metrics.FalsePositives(name='fp'),
    metrics.TrueNegatives(name='tn'), metrics.FalseNegatives(name='fn'),
    metrics.BinaryAccuracy(name='accuracy'), metrics.Precision(name='precision'),
    metrics.Recall(name='recall'), metrics.AUC(name='auc'),
    metrics.MeanAbsoluteError(name='mae')]
t_unet.compile(
    loss=loss.BinaryCrossentropy(), # we use the binary cross-entropy to optimize
    optimizer=opt.Adam(lr=1e-3), # we use ADAM to optimize
    metrics=mlist # we keep track of the metrics in mlist
)
WARNING:tensorflow:From /home/travis/miniconda/envs/book/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:3172: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version. Instructions for updating: Use tf.where in 2.0, which has the same broadcast rule as np.where
This is a very bad way to train a model;
The goal is to be aware of these techniques and have a feeling for how they can work for complex problems
# train on the single training crop repeated 3 times; validate on the validation crop
loss_history = t_unet.fit(prep_img(train_img, n=3),
                          prep_mask(train_mask, n=3),
                          validation_data=(prep_img(valid_img),
                                           prep_mask(valid_mask)),
                          epochs=20,
                          verbose = 1)
Train on 3 samples, validate on 1 samples Epoch 1/20 3/3 [==============================] - 11s 4s/step - loss: 0.0669 - tp: 0.0000e+00 - fp: 0.0000e+00 - tn: 304656.0000 - fn: 2544.0000 - accuracy: 0.9917 - precision: 0.0000e+00 - recall: 0.0000e+00 - auc: 0.5612 - mae: 0.0336 - val_loss: 0.1487 - val_tp: 2.0000 - val_fp: 3.0000 - val_tn: 593513.0000 - val_fn: 6482.0000 - val_accuracy: 0.9892 - val_precision: 0.4000 - val_recall: 3.0845e-04 - val_auc: 0.5572 - val_mae: 0.0109 Epoch 2/20 3/3 [==============================] - 7s 2s/step - loss: 0.1149 - tp: 0.0000e+00 - fp: 0.0000e+00 - tn: 304656.0000 - fn: 2544.0000 - accuracy: 0.9917 - precision: 0.0000e+00 - recall: 0.0000e+00 - auc: 0.5503 - mae: 0.0084 - val_loss: 0.0814 - val_tp: 14.0000 - val_fp: 7.0000 - val_tn: 593509.0000 - val_fn: 6470.0000 - val_accuracy: 0.9892 - val_precision: 0.6667 - val_recall: 0.0022 - val_auc: 0.7863 - val_mae: 0.0116 Epoch 3/20 3/3 [==============================] - 7s 2s/step - loss: 0.0620 - tp: 3.0000 - fp: 0.0000e+00 - tn: 304656.0000 - fn: 2541.0000 - accuracy: 0.9917 - precision: 1.0000 - recall: 0.0012 - auc: 0.7886 - mae: 0.0092 - val_loss: 0.0973 - val_tp: 78.0000 - val_fp: 53.0000 - val_tn: 593463.0000 - val_fn: 6406.0000 - val_accuracy: 0.9892 - val_precision: 0.5954 - val_recall: 0.0120 - val_auc: 0.7086 - val_mae: 0.0776 Epoch 4/20 3/3 [==============================] - 7s 2s/step - loss: 0.0951 - tp: 48.0000 - fp: 15.0000 - tn: 304641.0000 - fn: 2496.0000 - accuracy: 0.9918 - precision: 0.7619 - recall: 0.0189 - auc: 0.7450 - mae: 0.0800 - val_loss: 0.0711 - val_tp: 71.0000 - val_fp: 35.0000 - val_tn: 593481.0000 - val_fn: 6413.0000 - val_accuracy: 0.9892 - val_precision: 0.6698 - val_recall: 0.0110 - val_auc: 0.7740 - val_mae: 0.0493 Epoch 5/20 3/3 [==============================] - 7s 2s/step - loss: 0.0678 - tp: 42.0000 - fp: 0.0000e+00 - tn: 304656.0000 - fn: 2502.0000 - accuracy: 0.9919 - precision: 1.0000 - recall: 0.0165 - auc: 0.7879 - mae: 0.0518 - 
val_loss: 0.0681 - val_tp: 44.0000 - val_fp: 12.0000 - val_tn: 593504.0000 - val_fn: 6440.0000 - val_accuracy: 0.9892 - val_precision: 0.7857 - val_recall: 0.0068 - val_auc: 0.7950 - val_mae: 0.0143 Epoch 6/20 3/3 [==============================] - 7s 2s/step - loss: 0.0507 - tp: 33.0000 - fp: 0.0000e+00 - tn: 304656.0000 - fn: 2511.0000 - accuracy: 0.9918 - precision: 1.0000 - recall: 0.0130 - auc: 0.7877 - mae: 0.0150 - val_loss: 0.0638 - val_tp: 42.0000 - val_fp: 10.0000 - val_tn: 593506.0000 - val_fn: 6442.0000 - val_accuracy: 0.9892 - val_precision: 0.8077 - val_recall: 0.0065 - val_auc: 0.8581 - val_mae: 0.0115 Epoch 7/20 3/3 [==============================] - 7s 2s/step - loss: 0.0473 - tp: 33.0000 - fp: 0.0000e+00 - tn: 304656.0000 - fn: 2511.0000 - accuracy: 0.9918 - precision: 1.0000 - recall: 0.0130 - auc: 0.8675 - mae: 0.0092 - val_loss: 0.0571 - val_tp: 48.0000 - val_fp: 10.0000 - val_tn: 593506.0000 - val_fn: 6436.0000 - val_accuracy: 0.9893 - val_precision: 0.8276 - val_recall: 0.0074 - val_auc: 0.8322 - val_mae: 0.0182 Epoch 8/20 3/3 [==============================] - 7s 2s/step - loss: 0.0437 - tp: 36.0000 - fp: 0.0000e+00 - tn: 304656.0000 - fn: 2508.0000 - accuracy: 0.9918 - precision: 1.0000 - recall: 0.0142 - auc: 0.8603 - mae: 0.0141 - val_loss: 0.0515 - val_tp: 65.0000 - val_fp: 10.0000 - val_tn: 593506.0000 - val_fn: 6419.0000 - val_accuracy: 0.9893 - val_precision: 0.8667 - val_recall: 0.0100 - val_auc: 0.8736 - val_mae: 0.0171 Epoch 9/20 3/3 [==============================] - 7s 2s/step - loss: 0.0395 - tp: 42.0000 - fp: 0.0000e+00 - tn: 304656.0000 - fn: 2502.0000 - accuracy: 0.9919 - precision: 1.0000 - recall: 0.0165 - auc: 0.8920 - mae: 0.0134 - val_loss: 0.0461 - val_tp: 79.0000 - val_fp: 14.0000 - val_tn: 593502.0000 - val_fn: 6405.0000 - val_accuracy: 0.9893 - val_precision: 0.8495 - val_recall: 0.0122 - val_auc: 0.9150 - val_mae: 0.0150 Epoch 10/20 3/3 [==============================] - 7s 2s/step - loss: 0.0356 - tp: 57.0000 - fp: 
0.0000e+00 - tn: 304656.0000 - fn: 2487.0000 - accuracy: 0.9919 - precision: 1.0000 - recall: 0.0224 - auc: 0.9258 - mae: 0.0120 - val_loss: 0.0426 - val_tp: 88.0000 - val_fp: 15.0000 - val_tn: 593501.0000 - val_fn: 6396.0000 - val_accuracy: 0.9893 - val_precision: 0.8544 - val_recall: 0.0136 - val_auc: 0.9390 - val_mae: 0.0137 Epoch 11/20 3/3 [==============================] - 7s 2s/step - loss: 0.0334 - tp: 63.0000 - fp: 3.0000 - tn: 304653.0000 - fn: 2481.0000 - accuracy: 0.9919 - precision: 0.9545 - recall: 0.0248 - auc: 0.9444 - mae: 0.0111 - val_loss: 0.0405 - val_tp: 103.0000 - val_fp: 23.0000 - val_tn: 593493.0000 - val_fn: 6381.0000 - val_accuracy: 0.9893 - val_precision: 0.8175 - val_recall: 0.0159 - val_auc: 0.9489 - val_mae: 0.0131 Epoch 12/20 3/3 [==============================] - 7s 2s/step - loss: 0.0317 - tp: 81.0000 - fp: 3.0000 - tn: 304653.0000 - fn: 2463.0000 - accuracy: 0.9920 - precision: 0.9643 - recall: 0.0318 - auc: 0.9494 - mae: 0.0108 - val_loss: 0.0392 - val_tp: 123.0000 - val_fp: 33.0000 - val_tn: 593483.0000 - val_fn: 6361.0000 - val_accuracy: 0.9893 - val_precision: 0.7885 - val_recall: 0.0190 - val_auc: 0.9549 - val_mae: 0.0127 Epoch 13/20 3/3 [==============================] - 7s 2s/step - loss: 0.0305 - tp: 90.0000 - fp: 6.0000 - tn: 304650.0000 - fn: 2454.0000 - accuracy: 0.9920 - precision: 0.9375 - recall: 0.0354 - auc: 0.9542 - mae: 0.0105 - val_loss: 0.0380 - val_tp: 147.0000 - val_fp: 39.0000 - val_tn: 593477.0000 - val_fn: 6337.0000 - val_accuracy: 0.9894 - val_precision: 0.7903 - val_recall: 0.0227 - val_auc: 0.9586 - val_mae: 0.0123 Epoch 14/20 3/3 [==============================] - 7s 2s/step - loss: 0.0293 - tp: 105.0000 - fp: 12.0000 - tn: 304644.0000 - fn: 2439.0000 - accuracy: 0.9920 - precision: 0.8974 - recall: 0.0413 - auc: 0.9561 - mae: 0.0101 - val_loss: 0.0368 - val_tp: 195.0000 - val_fp: 56.0000 - val_tn: 593460.0000 - val_fn: 6289.0000 - val_accuracy: 0.9894 - val_precision: 0.7769 - val_recall: 0.0301 - 
val_auc: 0.9615 - val_mae: 0.0125 Epoch 15/20 3/3 [==============================] - 7s 2s/step - loss: 0.0282 - tp: 138.0000 - fp: 12.0000 - tn: 304644.0000 - fn: 2406.0000 - accuracy: 0.9921 - precision: 0.9200 - recall: 0.0542 - auc: 0.9610 - mae: 0.0101 - val_loss: 0.0356 - val_tp: 242.0000 - val_fp: 72.0000 - val_tn: 593444.0000 - val_fn: 6242.0000 - val_accuracy: 0.9894 - val_precision: 0.7707 - val_recall: 0.0373 - val_auc: 0.9641 - val_mae: 0.0128 Epoch 16/20 3/3 [==============================] - 7s 2s/step - loss: 0.0273 - tp: 165.0000 - fp: 30.0000 - tn: 304626.0000 - fn: 2379.0000 - accuracy: 0.9921 - precision: 0.8462 - recall: 0.0649 - auc: 0.9657 - mae: 0.0104 - val_loss: 0.0346 - val_tp: 277.0000 - val_fp: 95.0000 - val_tn: 593421.0000 - val_fn: 6207.0000 - val_accuracy: 0.9895 - val_precision: 0.7446 - val_recall: 0.0427 - val_auc: 0.9666 - val_mae: 0.0128 Epoch 17/20 3/3 [==============================] - 7s 2s/step - loss: 0.0267 - tp: 186.0000 - fp: 51.0000 - tn: 304605.0000 - fn: 2358.0000 - accuracy: 0.9921 - precision: 0.7848 - recall: 0.0731 - auc: 0.9670 - mae: 0.0105 - val_loss: 0.0335 - val_tp: 334.0000 - val_fp: 118.0000 - val_tn: 593398.0000 - val_fn: 6150.0000 - val_accuracy: 0.9895 - val_precision: 0.7389 - val_recall: 0.0515 - val_auc: 0.9689 - val_mae: 0.0131 Epoch 18/20 3/3 [==============================] - 7s 2s/step - loss: 0.0260 - tp: 201.0000 - fp: 72.0000 - tn: 304584.0000 - fn: 2343.0000 - accuracy: 0.9921 - precision: 0.7363 - recall: 0.0790 - auc: 0.9699 - mae: 0.0110 - val_loss: 0.0329 - val_tp: 372.0000 - val_fp: 141.0000 - val_tn: 593375.0000 - val_fn: 6112.0000 - val_accuracy: 0.9895 - val_precision: 0.7251 - val_recall: 0.0574 - val_auc: 0.9701 - val_mae: 0.0129 Epoch 19/20 3/3 [==============================] - 7s 2s/step - loss: 0.0252 - tp: 225.0000 - fp: 81.0000 - tn: 304575.0000 - fn: 2319.0000 - accuracy: 0.9921 - precision: 0.7353 - recall: 0.0884 - auc: 0.9716 - mae: 0.0106 - val_loss: 0.0323 - val_tp: 
415.0000 - val_fp: 161.0000 - val_tn: 593355.0000 - val_fn: 6069.0000 - val_accuracy: 0.9896 - val_precision: 0.7205 - val_recall: 0.0640 - val_auc: 0.9716 - val_mae: 0.0128 Epoch 20/20 3/3 [==============================] - 7s 2s/step - loss: 0.0247 - tp: 240.0000 - fp: 102.0000 - tn: 304554.0000 - fn: 2304.0000 - accuracy: 0.9921 - precision: 0.7018 - recall: 0.0943 - auc: 0.9731 - mae: 0.0105 - val_loss: 0.0318 - val_tp: 447.0000 - val_fp: 176.0000 - val_tn: 593340.0000 - val_fn: 6037.0000 - val_accuracy: 0.9896 - val_precision: 0.7175 - val_recall: 0.0689 - val_auc: 0.9719 - val_mae: 0.0129
# map metric keys to human-readable panel titles
titleDict = {'tp': "True Positives",'fp': "False Positives",'tn': "True Negatives",'fn': "False Negatives", 'accuracy':"BinaryAccuracy",'precision': "Precision",'recall':"Recall",'auc': "Area under Curve", 'mae': "Mean absolute error"}
fig,ax = plt.subplots(2,5, figsize=(20,8), dpi=300)
ax =ax.ravel()
# one panel per metric: training vs validation curve over the epochs
for idx,key in enumerate(titleDict.keys()):
    ax[idx].plot(loss_history.epoch, loss_history.history[key], color='coral', label='Training')
    ax[idx].plot(loss_history.epoch, loss_history.history['val_'+key], color='cornflowerblue', label='Validation')
    ax[idx].set_title(titleDict[key]);
ax[9].axis('off');  # nine metrics only — hide the unused tenth panel
axLine, axLabel = ax[0].get_legend_handles_labels() # Take the labels and plot line information from the first panel
lines =[]; labels = []; lines.extend(axLine); labels.extend(axLabel);fig.legend(lines, labels, bbox_to_anchor=(0.7, 0.3), loc='upper left');
# predict on (a window of) the training region to inspect the fit on seen data
unet_train_pred = t_unet.predict(prep_img(train_img[:,wpos[1]:(wpos[1]+ww)]))[0, :, :, 0]
fig, m_axs = plt.subplots(1, 3, figsize=(18, 4), dpi=150); m_axs= m_axs.ravel();
for c_ax in m_axs: c_ax.axis('off')
m_axs[0].imshow(train_img[:,wpos[1]:(wpos[1]+ww)], cmap='bone', vmin=0, vmax=4000), m_axs[0].set_title('Train Image')
m_axs[1].imshow(unet_train_pred, cmap='viridis', vmin=0, vmax=0.2), m_axs[1].set_title('Predicted Training')
m_axs[2].imshow(train_mask[:,wpos[1]:(wpos[1]+ww)], cmap='viridis'), m_axs[2].set_title('Train Mask');
# predict on the unseen test window after training
unet_pred = t_unet.predict(prep_img(forigc))[0, :, :, 0]
fig, m_axs = plt.subplots(1, 3, figsize=(18, 4), dpi=150); m_axs = m_axs.ravel() ;
for c_ax in m_axs: c_ax.axis('off')
m_axs[0].imshow(forigc, cmap='bone', vmin=0, vmax=4000); m_axs[0].set_title('Full Image')
f1=m_axs[1].imshow(unet_pred, cmap='viridis', vmin=0, vmax=0.1); m_axs[1].set_title('Predicted Segmentation'); fig.colorbar(f1,ax=m_axs[1]);
m_axs[2].imshow(maskc,cmap='viridis'); m_axs[2].set_title('Ground Truth');
fig, ax = plt.subplots(1,2, figsize=(12,4))
ax0=ax[0].imshow(unet_pred, vmin=0, vmax=0.1); ax[0].set_title('Predicted segmentation'); fig.colorbar(ax0,ax=ax[0])
# threshold the network response at 0.05 to obtain the binary segmentation
# (fixed typo in the displayed title: 'segmenation' -> 'segmentation')
ax[1].imshow(0.05<unet_pred), ax[1].set_title('Final segmentation');
gt = maskc           # ground truth mask
pr = 0.05<unet_pred  # binary prediction
# visualize hit/miss cases and the hit map of prediction vs ground truth
ps.showHitCases(gt,pr,cmap='gray')
fig, ax = plt.subplots(1,2,figsize=(12,4))
ps.showHitMap(gt,pr,ax=ax)